import torchvision
import torch
import torch.nn as nn
from torch.utils import data
from glob import glob
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
from PIL import Image
import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using device:", device)

# 加载resnet18模型
# resnet18 = models.resnet18(pretrained=False)
# resnet101 = models.resnet101(pretrained=True)
densenet201 = models.densenet201(pretrained=True)
# squeezenet_1_1 = models.squeezenet1_1(pretrained=True)

# 获取resnet18最后一层输出，输出为512维,最后一层本来是用作 分类的，原始网络分为1000类
# 用 softmax函数或者 fully connected 函数，但是用 nn.identtiy() 函数把最后一层替换掉，相当于得到分类之前的特征！
# Identity模块，它将输入直接传递给输出，而不会对输入进行任何变换。
# resnet18.fc = nn.Identity()
# resnet101.fc = nn.Identity()
densenet201.fc = nn.Identity()


# squeezenet_1_1.classifier = nn.Sequential(
#     nn.Dropout(p=0.5, inplace=False),
#     nn.Conv2d(512, 1000, kernel_size=(1, 1), stride=(1, 1))
# )
info_labels = []


def get_label(x):
    if isinstance(x, str):
        if x not in info_labels:
            info_labels.append(x)
        return info_labels.index(x)
    else:
        return info_labels[x]


class MyDataSet(data.Dataset):

    def __init__(self, data_path, transform):
        self.labels = []
        self.imgs = []
        self.transform = transform

        data_list = glob(f"{data_path}/*/*.jpg")
        for img_path in data_list:
            self.imgs.append(img_path)
            label = img_path.split("\\")[-2]
            label = get_label(label)
            self.labels.append(label)
            # print(img_path)
            # break

    def __getitem__(self, item):
        img_path = self.imgs[item]
        img = Image.open(img_path)
        img_data = self.transform(img)
        label = self.labels[item]
        return img_data, np.array(label)

    def __len__(self):
        return len(self.imgs)


# 构建新的网络，将resnet18的输出作为输入
# 构建新的网络，将resnet18的输出作为输入
class Attention(nn.Module):
    def __init__(self, in_channels):
        super(Attention, self).__init__()
        # 输入及输出都为3通道，不改变原始图片通道数
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x):
        a = F.relu(self.conv(x))
        a = F.softmax(a.view(a.size(0), -1), dim=1).view_as(a)
        x = x * a
        return x


# 构建新的网络，将resnet18的输出作为输入
class Net(nn.Module):
    def __init__(self, num_class: int):
        super(Net, self).__init__()
        self.num_class = num_class
        # 注意力，用于区分输入图片重要的部分
        self.attention1 = Attention(3)
        # self.resnet18 = resnet18
        # self.resnet101 = resnet101
        self.densenet201 = densenet201
        # self.squeezenet_1_1 = squeezenet_1_1
        # self.fc0 = nn.Linear(9000, 1000)
        self.fc1 = nn.Linear(1000, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 64)
        self.fc4 = nn.Linear(64, num_class)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.attention1(x)
        # x = self.resnet18(x)
        # x = self.resnet101(x)
        x = self.densenet201(x)
        # x = self.squeezenet_1_1(x)
        # x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.softmax(x)
        x = x.view(-1, 2)
        return x


# 实例化网络
model = Net(len(info_labels))
# 将模型放入GPU
model = model.to(device)

# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 定义优化器，添加l2正则化
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=0)


# 加载数据集
# 创建一个transform对象
def rgb2bgr(image):
    image = np.array(image)[:, :, ::-1]
    image = Image.fromarray(np.uint8(image))
    return image


transform_list = [
    # transforms.Resize((224, 224)),
    # #transforms.ColorJitter的参数主要有：brightness，contrast，saturation和hue。
    # # brightness：用于调整图像的亮度，取值范围为[0, 1]，如果设置为0.5，则图像的亮度会随机增加或减少0-50%。
    # # contrast：用于调整图像的对比度，取值范围为[0, 1]，如果设置为0.5，则图像的对比度会随机增加或减少0-50%。
    # # saturation：用于调整图像的饱和度，取值范围为[0, 1]，如果设置为0.5，则图像的饱和度会随机增加或减少0-50%。
    # # hue：用于调整图像的色调，取值范围为[-0.5, 0.5]，如果设置为0.5，则图像的色调会随机增加或减少0-50%。
    #  transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=1),
    # # 原始图像在旋转-90度到90度也能被识别
    transforms.RandomRotation(degrees=(-45, 45)),
    # # 抵抗模糊的影响:参数3是模糊程度的参数，它表示高斯滤波器的核的大小
    transforms.GaussianBlur(3),
    # # size：表示裁剪出来的图像的大小，即输出图像的大小。
    # # scale：表示裁剪框的大小与原始图像的比例，范围是0.08到1.0，0.08表示原始图像的最小尺寸，1.0表示原始图像的最大尺寸。
    # # ratio：表示裁剪框的宽高比。
    # # interpolation：表示插值方法，2表示双线性插值。
    # torchvision.transforms.RandomResizedCrop((112, 112), scale=(0.8, 1), ratio=(1, 1), interpolation=2),
    # #0.5的概率水平翻转
    transforms.RandomHorizontalFlip(p=0.5),
    # degrees：可从中选择的度数范围。如果为非零数字，旋转角度从（-degrees,+degress),或者可设置为（min,max)
    # translate:水平和垂直平移的最大绝对偏移量。长度为2的元组，数值在(0,1)之间，dx在（-w*a,w*a)，dy在（-h*b,h*b)
    # scale:比例因子区间，例如（a，b），则从范围a<=比例<=b中随机采样比例。默认情况下，将保留原始比例。
    # shear:可从中选择的度数范围。放射变换的角度，若为 (a, b)，x 轴在 (-a, a) 之间随机选择错切角度，y 轴在 (-b, b) 之间随机选择错切角度,用灰色填充
    transforms.RandomAffine(degrees=(-45, 45), translate=(0.1, 0.3), scale=(0.8, 1.2), shear=(10, 10),
                            fill=(114, 114, 114))
]

transformR = transforms.Compose([transforms.RandomApply(transform_list, 0.1)])
transform = transforms.Compose([
    transformR,
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),
    # rgb转bgr
    torchvision.transforms.Lambda(rgb2bgr),
    torchvision.transforms.Resize((64, 64)),
    # 入的图片为PIL image 或者 numpy.nadrry格式的图片，其shape为（HxWxC）数值范围在[0,255],转换之后shape为（CxHxw）,数值范围在[0,1]
    transforms.ToTensor(),
    # 进行归一化和标准化，Imagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，
    # 因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]),
    # #这行代码表示使用transforms.RandomErasing函数，以概率p=1，在图像上随机选择一个尺寸为scale=(0.02, 0.33)，长宽比为ratio=(1, 1)的区域，
    # #进行随机像素值的遮盖，只能对tensor操作：
    # transforms.RandomErasing(p=0.1, scale=(0.02, 0.2), ratio=(1, 1), value='random')
])
transform2 = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ColorJitter(brightness=1, contrast=1, saturation=1, hue=0.5),
    # rgb转bgr
    torchvision.transforms.Lambda(rgb2bgr),
    # 入的图片为PIL image 或者 numpy.nadrry格式的图片，其shape为（HxWxC）数值范围在[0,255],转换之后shape为（CxHxw）,数值范围在[0,1]
    transforms.ToTensor(),
    # 进行归一化和标准化，Imagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，
    # 因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。
    transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
])
print("开始读取数据集")
# train_dataset = torchvision.datasets.ImageFolder(r'./data/train', transform=transform)
train_dataset = MyDataSet(r'./data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
# test_dataset = torchvision.datasets.ImageFolder(r'./data/test', transform=transform2)
test_dataset = MyDataSet(r'./data/test', transform=transform2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=True)
print("读取数据集完成")

# 绘制训练及测试集迭代曲线
# 记录训练集准确率
train_acc = []
# 记录测试集准确率
test_acc = []
for epoch in range(50):
    print(f"{epoch} 开始训练")
    running_loss = 0.0
    # [(0, data1), (1, data2), (2, data3), ...]
    for i, data in enumerate(train_loader, 0):
        # 获取输入
        inputs, labels = data
        labels = torch.tensor(labels).long()
        inputs, labels = inputs.to(device), labels.to(device)
        # 梯度清零
        optimizer.zero_grad()
        # forward + backward
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # 更新参数
        optimizer.step()
        # 打印log信息
        # loss 是一个scalar,需要使用loss.item()来获取数值，不能使用loss[0]
        running_loss += loss.item()

        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss))
        running_loss = 0.0
        # 在每次训练完成后，使用测试集进行测试
        correct = 0
        total = 0
        with torch.no_grad():
            for i2, data2 in enumerate(test_loader):
                # 控制测试集的数量
                if i2 > 5:
                    break
                images, labels = data2
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        test_acc.append(100 * correct / total)
        print('Accuracy of the network on the test images: %.3f,now max acc is %.3f' % (
            100 * correct / total, max(test_acc)))
        # 保存测试集上准确率最高的模型
        if 100 * correct / total == max(test_acc):
            if not os.path.exists(r'./result'):
                os.makedirs(r'./result')
            # if max(test_acc) > 95:
    # save_name = f"./result/epoch{epoch}bestmodel" + "%.3f" % max(test_acc) + ".pth"
    save_name = f"./result/epoch{epoch}_model_acc{max(test_acc)}.pth"
    torch.save(model.state_dict(), save_name)

print("最大准确度：", max(test_acc))

# # 绘制训练及测试集迭代曲线
# plt.plot(range(len(train_acc)), train_acc, color='blue', label='Train Acc')
# plt.plot(range(len(test_acc)), test_acc, color='red', label='Test Acc')
# plt.legend()
# plt.title("Accuracy Curve")
# plt.show()


# 测试的脚本
# 在验证集上进行测试
# savename = "./result/bestmodel" + "%.3f" % max(test_acc) + ".pth"
# model.load_state_dict(torch.load(savename))
#
# val_dataset = torchvision.datasets.ImageFolder(r'D:\eyeDataSet\validate', transform=transform2)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=True)
#
# # 验证集上的表现情况：
# correct = 0
# total = 0
# with torch.no_grad():
#     for data in val_loader:
#         images, labels = data
#         images, labels = images.cuda(), labels.cuda()
#         outputs = model(images)
#         _, predicted = torch.max(outputs.data, 1)
#         total += labels.size(0)
#         correct += (predicted == labels).sum().item()
#
# print('Accuracy of the network on the validation images: %.3f %%' % (100 * correct / total))
